{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 链式法则\n",
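    "如果某个函数由复合函数表示，则该复合函数的导数可以用构成复合函数的各个函数的导数的乘积表示。例如，对于 $z = t^2$、$t = x + y$，有 $\\frac{\\partial z}{\\partial x} = \\frac{\\partial z}{\\partial t}\\frac{\\partial t}{\\partial x} = 2t \\cdot 1 = 2(x + y)$。"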
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from layer import MulLayer, AddLayer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "220.00000000000003\n",
      "2.2 110.00000000000001 200\n"
     ]
    }
   ],
   "source": [
    "# Forward and backward propagation through multiplication layers: buying 2 apples\n",
    "apple, apple_num, tax = 100, 2, 1.1  # unit price, quantity, tax factor\n",
    "\n",
    "# One multiplication layer per product node of the computation graph\n",
    "mul_apple_layer, mul_tax_layer = MulLayer(), MulLayer()\n",
    "\n",
    "# Forward pass: (apple * apple_num) first, then the tax factor is applied\n",
    "apple_price = mul_apple_layer.forward(apple, apple_num)\n",
    "price = mul_tax_layer.forward(apple_price, tax)\n",
    "print(price)\n",
    "\n",
    "# Backward pass: start from an upstream gradient of 1 and visit the layers\n",
    "# in the reverse of the forward order\n",
    "dprice = 1\n",
    "dapple_price, dtax = mul_tax_layer.backward(dprice)\n",
    "dapple, dapple_num = mul_apple_layer.backward(dapple_price)\n",
    "print(dapple, dapple_num, dtax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "200 450 650 715.0000000000001\n",
      "110.00000000000001 2.2 3.3000000000000003 165.0 650\n"
     ]
    }
   ],
   "source": [
    "# Buying 2 apples and 3 oranges\n",
    "apple, apple_num = 100, 2  # apple: unit price, quantity\n",
    "orange, orange_num = 150, 3  # orange: unit price, quantity\n",
    "tax = 1.1  # tax factor\n",
    "\n",
    "# Layers, one per node of the computation graph\n",
    "mul_apple_layer, mul_orange_layer = MulLayer(), MulLayer()\n",
    "add_apple_orange_layer = AddLayer()\n",
    "mul_tax_layer = MulLayer()\n",
    "\n",
    "# Forward pass: price each fruit, sum the two subtotals, then apply the tax factor\n",
    "apple_price = mul_apple_layer.forward(apple, apple_num)\n",
    "orange_price = mul_orange_layer.forward(orange, orange_num)\n",
    "all_price = add_apple_orange_layer.forward(apple_price, orange_price)\n",
    "price = mul_tax_layer.forward(all_price, tax)\n",
    "print(apple_price, orange_price, all_price, price)\n",
    "\n",
    "# Backward pass: start from an upstream gradient of 1 and walk the graph\n",
    "# from the tax layer back to the individual fruit layers\n",
    "dprice = 1\n",
    "dall_price, dtax = mul_tax_layer.backward(dprice)\n",
    "dapple_price, dorange_price = add_apple_orange_layer.backward(dall_price)\n",
    "dorange, dorange_num = mul_orange_layer.backward(dorange_price)\n",
    "dapple, dapple_num = mul_apple_layer.backward(dapple_price)\n",
    "# Note the print order: dapple_num comes before dapple here\n",
    "print(dapple_num, dapple, dorange, dorange_num, dtax)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.13 ('d2l')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "6bdcd3dff61cfecb093f68bfbd67338d4a9739a04d65268bf3b66287b86a0a9e"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
